# ------------------------------------------------------------------------------
# Tunes GBMs
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=25000]" bash 03_tune_gbms.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fix R's random number generator state up front so that repeated runs of this
# script start from the same seed.
set.seed(1)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(Matrix)
library(xgboost)
library(optparse)
library(here) # here() relative filepaths
library(testit) # assert function

# Project utilities (loaded as a module) and the scratch directory this step
# reads its inputs from and writes its outputs to.
u <- modules::use(here("lib", "util.R"))
temp <- here("code", "08_train_split_model", "temp")

# Load Data --------------------------------------------------------------------
# `x` comes from the path configured under features$train_tested in
# lib/filepaths.yml; `ids` comes from split_train_cohort.rds in `temp`.
message("Loading data...")
paths <- read_yaml(here("lib", "filepaths.yml"))
features_fp <- paths$features$train_tested
cohort_fp <- file.path(temp, "split_train_cohort.rds")
x <- readRDS(features_fp)
ids <- readRDS(cohort_fp)

# Subset Data ------------------------------------------------------------------
# Keep only rows for which none of the three exclusion flags is set. The same
# row positions are applied to both `x` and `ids` (assumes the two objects are
# row-aligned -- TODO confirm against the upstream step).
message("Subsetting data...")
keep_pop <- !(
  ids$excl_flag_c_int | ids$excl_flag_chronic | ids$excl_flag_death
)
keep <- which(keep_pop)

ids <- ids[keep, ]
x <- x[keep, ]

# Hyperparameter grid: one row per combination of the values listed below
# (eta and colsample_bytree are each held at a single value).
tuning_design <- crossing(
  eta = 0.05,
  max_depth = c(7, 8, 9),
  subsample = c(0.75, 0.85),
  colsample_bytree = 0.5,
  min_child_weight = c(1, 5, 10)
)

# Tuning -----------------------------------------------------------------------
# For each sample split (1 and 2): restrict the data to that split, build CV
# folds from `train_fold`, run xgb.cv once per row of `tuning_design`, and save
# the grid together with each run's evaluation log.
message("Tuning GBMs...")
for (i in c(1, 2)) {
  message("Split ", i)
  split_var <- glue("sample_split{i}")
  keep_split <- which(ids[[split_var]])

  x_split <- x[keep_split, ]
  ids_split <- ids[keep_split, ]

  # One vector of row positions (within the split) per distinct train_fold
  # value; handed to xgb.cv through its `folds` argument.
  folds <- map(
    unique(ids_split$train_fold), ~ which(ids_split$train_fold == .x)
  )

  # pmap() forwards every column of `tuning_design` to xgb.cv as a named
  # argument, so the grid must contain hyperparameter columns only. It is
  # therefore never modified inside this loop (see `split_result` below).
  tuning_result <- pmap(tuning_design, xgb.cv,
    data = x_split,
    label = ids_split$stent_or_cabg_010_day,
    folds = folds,
    nrounds = 10000L,
    early_stopping_rounds = 20L,
    objective = "binary:logistic"
  )

  # Attach the results to a per-split copy of the grid. Writing them back onto
  # `tuning_design` would make the next split's pmap() call pass `logloss` and
  # `num_iterations` to xgb.cv as if they were hyperparameters.
  split_result <- tuning_design

  # NOTE(review): unlist() flattens the entire evaluation_log (all of its
  # columns), not a single log-loss series, and num_iterations indexes that
  # flattened vector. Kept as-is for downstream compatibility -- confirm this
  # is what the consumers of the saved file expect.
  split_result$logloss <- map(tuning_result, ~ unlist(.x$evaluation_log))
  split_result$num_iterations <- map(split_result$logloss, seq_along)

  message("Saving model ", i, "...")
  save_fp <- file.path(
    temp,
    glue("tuning__gbm__stent_or_cabg_010_day__tested__{i}.rds")
  )
  saveRDS(split_result, save_fp)
}

# Done -------------------------------------------------------------------------
# Final status message, emitted once the tuning loop above has finished.
message("Done.")
